import torch
import megatron

from megatron.core.parallel_state import (
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size
)
from megatron.core.tensor_parallel.utils import VocabUtility


def vocab_parallel_cross_entropy_forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0):
    """Cross-entropy forward over logits sharded along the vocab dimension.

    Every tensor-model-parallel rank holds one equally sized slice of the
    vocabulary in the last dimension of ``vocab_parallel_logits``. The logit
    max, the target logit, the softmax normaliser and (when smoothing) the
    mean log-prob are combined across the tensor-model-parallel group with
    all-reduces, so each rank computes the full-vocabulary loss.

    Args:
        ctx: autograd context. The softmax, the target mask and the flattened
            local target indices are saved on it for backward, together with
            ``ctx.label_smoothing`` and ``ctx.vocab_size`` (the *global*
            vocabulary size).
        vocab_parallel_logits: this rank's logits; the last dimension is the
            local vocab partition. The caller's tensor is not modified.
        target: global token ids; must have one element per row of
            ``vocab_parallel_logits`` once its last dimension is flattened.
        label_smoothing: smoothing factor ``alpha``; values ``<= 0`` disable
            smoothing.

    Returns:
        The per-token loss, shaped like ``target``.

    Raises:
        ValueError: if ``label_smoothing >= 1.0``.
    """
    # Validate before issuing any collective.
    if label_smoothing >= 1.0:
        raise ValueError("label_smoothing value should in (0,1)")

    tp_group = get_tensor_model_parallel_group()

    # Maximum value along vocab dimension across all GPUs.
    logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
    torch.distributed.all_reduce(logits_max,
                                 op=torch.distributed.ReduceOp.MAX,
                                 group=tp_group)
    # Subtract the maximum for numerical stability. This is out of place, so
    # the in-place exp further down never touches the caller's tensor.
    vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1)

    # Get this partition's [start, end) range of global vocab ids.
    get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
    partition_vocab_size = vocab_parallel_logits.size()[-1]
    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_world_size()
    vocab_start_index, vocab_end_index = get_vocab_range(
        partition_vocab_size, rank, world_size)
    # The range helper works from a single per-partition size, i.e. every rank
    # holds the same number of ids, so the global vocab is world_size of them.
    vocab_size = partition_vocab_size * world_size

    # Mask of targets outside this partition (True = not owned by this rank).
    # Those are redirected to local index 0 so the gather below stays in
    # bounds; their gathered logit is zeroed afterwards.
    target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
    masked_target = target - vocab_start_index  # new tensor; target untouched
    masked_target.masked_fill_(target_mask, 0)

    # Get predicted-logits = logits[target].
    # For simplicity, flatten logits to [*, partition-vocab-size] and the
    # target to [*].
    logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
    masked_target_1d = masked_target.view(-1)
    arange_1d = torch.arange(start=0, end=logits_2d.size()[0],
                             device=logits_2d.device)
    predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
    predicted_logits_1d = predicted_logits_1d.clone().contiguous()
    predicted_logits = predicted_logits_1d.view_as(target)
    # masked_fill_ rather than multiplying by the mask: 0 * -inf would be NaN.
    predicted_logits.masked_fill_(target_mask, 0.0)
    # Each in-range target is owned by exactly one rank, so a SUM over the
    # group recovers its logit everywhere.
    torch.distributed.all_reduce(predicted_logits,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=tp_group)

    if label_smoothing > 0:
        # Mean of the shifted logits over the *global* vocab, needed for the
        # smoothing term. It must be taken before the in-place exp below
        # overwrites them. Averaging per rank (rather than summing) keeps the
        # intermediate within half-precision range.
        mean_shifted_logits = vocab_parallel_logits.mean(dim=-1)
        torch.distributed.all_reduce(mean_shifted_logits,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=tp_group)
        mean_shifted_logits.div_(world_size)

    # Sum of exponential of logits along vocab dimension across all GPUs.
    # The exp is done in place on the shifted (private) logits to save memory.
    exp_logits = vocab_parallel_logits
    torch.exp(vocab_parallel_logits, out=exp_logits)
    sum_exp_logits = exp_logits.sum(dim=-1)
    torch.distributed.all_reduce(sum_exp_logits,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=tp_group)

    log_sum_exp_logits = torch.log(sum_exp_logits)
    loss = log_sum_exp_logits - predicted_logits

    # Normalize to softmax probabilities; these are saved for backward.
    exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))

    if label_smoothing > 0:
        # Give the ground truth mass (1 - alpha) and every other index
        # alpha / (K - 1), with K the global vocab size and y_i = -log p_i:
        #   (1 - alpha) * y_gt + (alpha / (K - 1)) * sum_{i != gt} y_i
        # = ((K * (1 - alpha) - 1) / (K - 1)) * y_gt + (alpha / (K - 1)) * sum_i y_i
        # = (1 - (alpha * K) / (K - 1)) * y_gt + ((alpha * K) / (K - 1)) * sum_i y_i / K
        smoothing = label_smoothing * vocab_size / (vocab_size - 1)

        # log p_i = (z_i - max) - log(sum_j exp(z_j - max)). Averaging that
        # identity avoids a full-size log() temporary and stays finite where a
        # probability underflows to zero.
        mean_log_probs = mean_shifted_logits - log_sum_exp_logits
        loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs

    # NOTE(review): ctx.vocab_size is the global vocab size K, not the
    # partition size. Assuming the backward subtracts smoothing / K per index,
    # its gradient then matches the loss above; confirm against the backward.
    ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size

    # Store softmax, target-mask and masked-target for backward pass.
    ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)

    return loss

